-
Notifications
You must be signed in to change notification settings - Fork 64
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Added train by epoch for Trainer and added support for texts #12
base: main
Are you sure you want to change the base?
Conversation
this is awesome Marcus! will take a look at it tomorrow morning and get it merged! |
Awesome work, thanks for including your suggestions to the main, this allows better understanding on the user's side. |
meshgpt_pytorch/meshgpt_pytorch.py
Outdated
@@ -741,6 +741,7 @@ def forward( | |||
vertices: TensorType['b', 'nv', 3, float], | |||
faces: TensorType['b', 'nf', 3, int], | |||
face_edges: Optional[TensorType['b', 'e', 2, int]] = None, | |||
texts: Optional[List[str]] = None, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ah, so the text is actually only conditioned through the transformer stage through cross attention
basically the autoencoder is given the job of only compressing meshes
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes I know :) But if you pass it a dict with texts it will give a error since the arg doesnt exist.
So then you would need two dataset classes.
Either replace the model(**forward_args) so it uses the prarameters directly:
model(vertices = data["vertices"], faces = data["faces"])
Or just implement a dummy texts :) There is probably a better solution
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
oh got it! yea, i can take care of that within the trainer class (just scrub out the text
and text_embed
keys)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think that will work, I'm not 100% since the dataloader passes the data and maybe copies it(?).
But it won't work if you access it without copying it since the dataset is returning the data and not copying/cloning, when you do del on a key, it will remove it completely from the dataset.
So if you train the encoder and then want to train a transformer, you'll need to recreate the dataset since the texts key is removed.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I prefer if the dataset returns the text with each vertices and faces.
meshgpt_pytorch/trainer.py
Outdated
@@ -367,7 +370,63 @@ def forward(self): | |||
self.wait() | |||
|
|||
self.print('training complete') | |||
def train(self, num_epochs, diplay_graph = False): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
small typo here diplay
|
||
|
||
self.print('Training complete') | ||
if diplay_graph: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
so i haven't documented this, but you can already use wandb.ai experiment tracker
you just have to do
trainer = Trainer(..., use_wandb_tracking = True)
with trainer.trackers('meshgpt', 'one-experiment-name'):
trainer.train()
Btw since I don't really think grad_accum_every is very useful I removed it from the train function, what is your option? I forgot and left grad_accum_every in the loss function, so if it wont be used in the train function it should be removed from:
|
i'm sure researchers will want to stretch to the next level if this approach pans out (multiple meshes, scenes etc) probably good to keep it for the gpu poor |
Another thing :) I'm not very experienced in using github forks but it seems like the pull request added later commits then when I made the request. I made bit of a error and replaced entire meshgpt_pytorch.py since there was some weird stash thing and I wanted to ensure it was up to date. I reverted but it seems like that stash thing messed it up bit, please double check if this is the case |
Update mesh_dataset.py from entrys to entries
Hi! Thanks for your amazing job! May I ask why in MeshGPT_dem.ipynb has a |
Trainer
- Added option to display graph (maybe remove since it requires matlib?)
Data.py
MeshAutoencoder
Setup.py